// nnet2/train-nnet.h

// Copyright 2012  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET2_TRAIN_NNET_H_
#define KALDI_NNET2_TRAIN_NNET_H_

#include "nnet2/nnet-update.h"
#include "nnet2/nnet-compute.h"
#include "itf/options-itf.h"

namespace kaldi {
namespace nnet2 {


struct NnetSimpleTrainerConfig {
  int32 minibatch_size;
  int32 minibatches_per_phase;
  
  NnetSimpleTrainerConfig(): minibatch_size(500),
                             minibatches_per_phase(50) { }
  
  void Register (OptionsItf *po) {
    po->Register("minibatch-size", &minibatch_size,
                 "Number of samples per minibatch of training data.");
    po->Register("minibatches-per-phase", &minibatches_per_phase,
                 "Number of minibatches to wait before printing training-set "
                 "objective.");
  }  
};


// Class NnetSimpleTrainer doesn't do much apart from batching up the
// input into minibatches and giving it to the neural net code
// to call Update(), which will typically do stochastic gradient
// descent.  It also reports the training-set objective; per the help
// string of the --minibatches-per-phase option, this is printed after
// every config.minibatches_per_phase minibatches (one "phase").
// It takes in the training examples through the call
// "TrainOnExample()".
class NnetSimpleTrainer {
 public:
  /// Takes the trainer options (stored by value in config_) and a pointer
  /// to the network to be trained (stored in nnet_).
  /// NOTE(review): ownership of "nnet" is not visible from this header;
  /// presumably the caller retains ownership and must keep the object
  /// alive for the lifetime of the trainer -- confirm against the .cc file.
  NnetSimpleTrainer(const NnetSimpleTrainerConfig &config,
                    Nnet *nnet);
  
  /// TrainOnExample will take the example and add it to a buffer;
  /// if we've reached the minibatch size it will do the training.
  void TrainOnExample(const NnetExample &value);

  /// Defined out of line.
  /// NOTE(review): presumably deals with any examples still waiting in the
  /// buffer when the trainer is destroyed -- confirm against the .cc file.
  ~NnetSimpleTrainer();
 private:
  // Going by the macro's name, this makes the trainer non-copyable and
  // non-assignable (the macro itself is defined elsewhere).
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetSimpleTrainer);
  
  // NOTE(review): presumably trains on the examples currently held in
  // buffer_; the definition is not visible from this header.
  void TrainOneMinibatch();
  
  // The following function is called by TrainOneMinibatch()
  // when we enter a new phase.
  // NOTE(review): "first_time" presumably marks the very first phase,
  // for which there are no previous-phase statistics to report -- confirm.
  void BeginNewPhase(bool first_time);
  
  // Things we were given in the initializer:
  NnetSimpleTrainerConfig config_; // our own copy of the options.

  Nnet *nnet_; // the nnet we're training.

  // State information:
  // NOTE(review): the two counters below are presumably the number of
  // phases begun so far and the number of minibatches processed in the
  // current phase -- confirm against the .cc file.
  int32 num_phases_;
  int32 minibatches_seen_this_phase_;
  // Examples of type NnetExample waiting to be trained on; presumably the
  // buffer that TrainOnExample() adds to.
  std::vector<NnetExample> buffer_;

  double logprob_this_phase_; // Needed for accumulating train log-prob on each phase.
  double count_this_phase_; // count corresponding to the above.
};



} // namespace nnet2
} // namespace kaldi

#endif
